"""Bayesian p-value Posterior/Prior predictive plot."""
import numpy as np

from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..utils import _var_names
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
from ..sel_utils import xarray_var_iter


def plot_bpv(
    data,
    kind="u_value",
    t_stat="median",
    bpv=True,
    plot_mean=True,
    reference="analytical",
    mse=False,
    n_ref=100,
    hdi_prob=0.94,
    color="C0",
    grid=None,
    figsize=None,
    textsize=None,
    labeller=None,
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    coords=None,
    flatten=None,
    flatten_pp=None,
    ax=None,
    backend=None,
    plot_ref_kwargs=None,
    backend_kwargs=None,
    group="posterior",
    show=None,
):
    r"""Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Parameters
    ----------
    data : InferenceData
        :class:`arviz.InferenceData` object containing the observed and
        posterior/prior predictive data.
    kind : {"u_value", "p_value", "t_stat"}, default "u_value"
        Specify the kind of plot (matched case-insensitively):

        * The ``kind="p_value"`` computes :math:`p := p(y* \leq y | y)`.
          This is the probability of the data y being larger or equal than the predicted data y*.
          The ideal value is 0.5 (half the predictions below and half above the data).
        * The ``kind="u_value"`` argument computes :math:`p_i := p(y_i* \leq y_i | y)`.
          i.e. like a p_value but per observation :math:`y_i`. This is also known as marginal
          p_value. The ideal distribution is uniform. This is similar to the LOO-PIT
          calculation/plot, the difference is that in the LOO-PIT plot we compute
          :math:`p_i = p(y_i* \leq y_i | y_{-i})`, where :math:`y_{-i}`
          is all other data except :math:`y_i`.
        * The ``kind="t_stat"`` argument computes :math:`:= p(T(y)* \leq T(y) | y)`
          where T is any test statistic. See ``t_stat`` argument below for details
          of available options.

    t_stat : str, float, or callable, default "median"
        Test statistics to compute from the observations and predictive distributions.
        Allowed strings are "mean", "median" or "std". Alternatively a quantile can be passed
        as a float (or str) in the interval (0, 1). Finally a user defined function is also
        accepted, see examples section for details.
    bpv : bool, default True
        If True add the Bayesian p_value to the legend when ``kind = t_stat``.
    plot_mean : bool, default True
        Whether or not to plot the mean test statistic.
    reference : {"analytical", "samples", None}, default "analytical"
        How to compute the distributions used as reference for ``kind=u_value``
        or ``kind=p_value`` (matched case-insensitively). Use `None` to not plot any reference.
    mse : bool, default False
        Show scaled mean square error between uniform distribution and marginal p_value
        distribution.
    n_ref : int, default 100
        Number of reference distributions to sample when ``reference=samples``.
    hdi_prob : float, default 0.94
        Probability for the highest density interval for the analytical reference distribution when
        ``kind=u_value``. Should be in the interval (0, 1]. If None, the
        rcParam ``stats.hdi_prob`` is used. See :ref:`this section <common_hdi_prob>` for
        usage examples.
    color : str, optional
        Matplotlib color
    grid : tuple, optional
        Number of rows and columns. By default, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for usage examples.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    data_pairs : dict, optional
        Dictionary containing relations between observed data and posterior/prior predictive data.
        Dictionary structure:

        - key = data var_name
        - value = posterior/prior predictive var_name

        For example, ``data_pairs = {'y' : 'y_hat'}``
        If None, it will assume that the observed data and the posterior/prior
        predictive data have the same variable name.
    var_names : list of str, optional
        Variables to be plotted. If `None` all variable are plotted. Prefix the variables by ``~``
        when you want to exclude them from the plot. See :ref:`this section <common_var_names>`
        for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension. Defaults to including all coordinates for all
        dimensions if None. The dictionary passed in is not modified.
        See :ref:`this section <common_coords>` for usage examples.
    flatten : list, optional
        List of dimensions to flatten in observed_data. Only flattens across the coordinates
        specified in the coords argument. Defaults to flattening all of the dimensions.
    flatten_pp : list, optional
        List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
        across the coordinates specified in the coords argument. Defaults to flattening all
        of the dimensions. Dimensions should match flatten excluding dimensions for data_pairs
        parameters. If `flatten` is defined and `flatten_pp` is None, then ``flatten_pp=flatten``.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
        its own array of plot areas (and return it).
    backend : str, optional
        Select plotting backend {"matplotlib", "bokeh"}. Defaults to the rcParam
        ``plot.backend``.
    plot_ref_kwargs :  dict, optional
        Extra keyword arguments to control how reference is represented.
        Passed to :meth:`matplotlib.axes.Axes.plot` or
        :meth:`matplotlib.axes.Axes.axhspan` (when ``kind=u_value``
        and ``reference=analytical``).
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    group : {"posterior", "prior"}, default "posterior"
        Specifies which InferenceData group should be plotted. If "posterior", then the values
        in `posterior_predictive` group are compared to the ones in `observed_data`, if "prior" then
        the same comparison happens, but with the values in `prior_predictive` group.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes or bokeh_figure

    Raises
    ------
    TypeError
        If `group`, `kind` or `reference` is not one of the allowed values, or if `data`
        lacks the ``observed_data`` or ``{group}_predictive`` group.
    ValueError
        If `hdi_prob` is not in the interval (0, 1].

    See Also
    --------
    plot_ppc : Plot for posterior/prior predictive checks.
    plot_loo_pit : Plot Leave-One-Out probability integral transformation (PIT) predictive checks.
    plot_dist_comparison : Plot to compare fitted and unfitted distributions.

    References
    ----------
    * Gelman et al. (2013) see http://www.stat.columbia.edu/~gelman/book/ pages 151-153 for details

    Examples
    --------
    Plot Bayesian p_values.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="p_value")

    Plot custom test statistic comparison.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> import numpy as np
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="t_stat", t_stat=lambda x:np.percentile(x, q=50, axis=-1))
    """
    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")

    for groups in (f"{group}_predictive", "observed_data"):
        if not hasattr(data, groups):
            raise TypeError(f'`data` argument must have the group "{groups}"')

    # Validation is case-insensitive, so forward the canonical lowercase spelling;
    # otherwise e.g. "U_Value" would pass the check but reach the backend unnormalised.
    kind = kind.lower()
    if kind not in ("t_stat", "u_value", "p_value"):
        raise TypeError("`kind` argument must be either `t_stat`, `u_value`, or `p_value`")

    if reference is not None:
        reference = reference.lower()
        if reference not in ("analytical", "samples"):
            raise TypeError(
                "`reference` argument must be either `analytical`, `samples`, or `None`"
            )

    if hdi_prob is None:
        hdi_prob = rcParams["stats.hdi_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if data_pairs is None:
        data_pairs = {}

    if labeller is None:
        labeller = BaseLabeller()

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    observed = data.observed_data
    # `group` was validated above, so this is either posterior_predictive or prior_predictive.
    predictive_dataset = getattr(data, f"{group}_predictive")

    if var_names is None:
        var_names = list(observed.data_vars)
    var_names = _var_names(var_names, observed, filter_vars)
    # Map each observed variable to its predictive counterpart (same name unless paired).
    pp_var_names = [data_pairs.get(var, var) for var in var_names]
    pp_var_names = _var_names(pp_var_names, predictive_dataset, filter_vars)

    # Iterating ``Dataset.dims`` yields the dimension names without relying on its
    # mapping-style ``keys()`` access.
    if flatten_pp is None:
        flatten_pp = list(predictive_dataset.dims) if flatten is None else flatten
    if flatten is None:
        flatten = list(observed.dims)

    total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]

    # Translate the requested coordinate values into positional indices for ``isel``.
    # A new dict is built so the caller's ``coords`` argument is left untouched.
    if coords is None:
        coords = {}
    coords = {
        key: np.where(np.isin(np.asarray(observed[key]), selection))[0]
        for key, selection in coords.items()
    }

    obs_plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                observed.isel(coords), skip_dims=set(flatten), var_names=var_names, combined=True
            )
        ),
        "plot_t_stats",
    )
    length_plotters = len(obs_plotters)
    # ``zip`` with ``range`` caps the predictive plotters at the number of observed plotters.
    pp_plotters = [
        tup
        for _, tup in zip(
            range(length_plotters),
            xarray_var_iter(
                predictive_dataset.isel(coords),
                var_names=pp_var_names,
                skip_dims=set(flatten_pp),
                combined=True,
            ),
        )
    ]
    rows, cols = default_grid(length_plotters, grid=grid)

    bpvplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        obs_plotters=obs_plotters,
        pp_plotters=pp_plotters,
        total_pp_samples=total_pp_samples,
        kind=kind,
        bpv=bpv,
        t_stat=t_stat,
        reference=reference,
        mse=mse,
        n_ref=n_ref,
        hdi_prob=hdi_prob,
        plot_mean=plot_mean,
        color=color,
        figsize=figsize,
        textsize=textsize,
        labeller=labeller,
        plot_ref_kwargs=plot_ref_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    plot = get_plotting_function("plot_bpv", "bpvplot", backend)
    axes = plot(**bpvplot_kwargs)
    return axes
